{
 "cells": [
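  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Problem Statement\n",
    "In a linear regression problem, you are given a set of data points and need to fit a straight line through them. In this example, we generate 101 noisy points and fit a line to them."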
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Creating the Training Data\n",
    "`trX` ranges from -1 to 1. `trY` is 3 times `trX` plus an offset of 1.25, with Gaussian noise (standard deviation 0.5) added."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "# Ground-truth slope, intercept, and noise standard deviation\n",
    "w = 3\n",
    "b = 1.25\n",
    "sigma = 0.5\n",
    "# 101 evenly spaced inputs in [-1, 1]\n",
    "trX = np.linspace(-1, 1, 101)\n",
    "# Targets: the true line plus Gaussian noise\n",
    "trY = w*trX + b + np.random.randn(*trX.shape) * sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x21dbe45f948>]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAWfElEQVR4nO3df6xkZX3H8c/37rJVFHVl10JZ7q4EQ1qMrXBDt9JUiNYiRa1tTbHU2qq5McFEYn9pTdDgP7UNprZuf9wiqbQE2witZosRrEtIa5Zy75Zf6xZdiFtXKCC9qATTZfd++8ecwXPnzpk5Z87znDnPmfcr2XDvzDnPeebM8J3nfp9f5u4CAKRrbtoVAADUQyAHgMQRyAEgcQRyAEgcgRwAErd5Ghfdtm2b79q1axqXBoBkraysfMfdtw8+PpVAvmvXLi0vL0/j0gCQLDM7MuxxUisAkDgCOQAkjkAOAIkjkANA4gjkAJA4AjkAJI5ADgA1rRxZ1Z59h7VyZHUq15/KOHIA6IqVI6u64rr9OnZ8TVs2z+nG9+zW+Tu3NloHWuQAUMP+h5/UseNrWnPp2eNr2v/wk43XgUAOADXsPutUbdk8p00mnbR5TrvPOrXxOpBaAYCSVo6sav/DT2r3Wac+lz45f+dW3fie3c89Lkl79h1ed0xsBHIAKGFULvz8nVt1/s6tU8uXk1oBgBLK5MKnlS8nkANACWVy4dPKl5u7N3KhvIWFBWcZWwCpGZYjn+SYSZnZirsvDD5OjhwASurnwuseExqpFQBIHIEcABJHIAeAxBHIASBxBHIAmMC0VzzMY9QKAFTUhhUP82iRA0BFbVjxMC9YIDezTWb2n2a2N1SZANBGk87gjJWOCZlaeb+kQ5JeFLBMAGidwRUPy6RVYqZjgrTIzWyHpF+UdF2I8gCg7c7fuVVXXnx26WAcMx0TKrXyp5J+X9Ja0QFmtmhmy2a2/MQTTwS6LACkIeaCWrVTK2Z2maTH3X3FzC4qOs7dlyQtSb1Fs+peFwBSMkk6pqwQOfILJb3ZzC6V9DxJLzKzv3f33whQNgB0RqwFtWqnVtz9Q+6+w913Sbpc0lcI4gDQHMaRA0Digs7sdPc7JN0RskwASFXMTSbymKIPABE0OY2f1AoARNDkNH4COQBE0ORGzKRWACCCmOPGBxHIASStqQ7FSTS1ETOBHECy6nYotvlLoAoCOYBkDetQLBuQ27Y5RB10dgJIVp0OxbZtDlEHLXIAyRrsUJSkPfsOl0qV9L8Enj2+Fn1USWzm3vxChAsLC768vNz4dQE0r6k89CSpktRy5Ga24u4Lg4/TIgcQTZN56Eny5U2NKomNHDmAaJrMQzc5AadtaJEDkBQnzRAyDz2ufnUn4BSVn0L6hRw5gKgpkBCBMHaKpqj8tg1RLMqRk1oBEDUFUnWT4lD1Wzmyqj37DmvlyOrE5acyRJHUCoDWD8WrWr+qLemi8tt+X/pIrQCQFDZHHCOvXKXMPfsO69rbHtSaS5tM+sAbztGVF589UfltypEXpVYI5AAKTTo2e9rrn/Tr0G9JX33ZuVp95lgrgnEdjCMHUNkkY7NHnZMP0v1j88E1VOdifgTL1pO36Jq9B1vTYRkDgRxAoUlyxEXn5IP05jmTzHT8xPrgWmcRrEH9yT579h0OVmZbEcgBFJpkbHbROeuC9AmX5HKtD64xOhdT6bCso3aO3MyeJ+lOST+i3hfD59z9I6POIUcOzJ583npT1iI/caIXXPPpjml3lLZZtM5OMzNJL3D3p83sJEn/Jun97r6/6BwCOTCbxuXIRx2fcgAOJVpnp/e+CZ7Ofj0p+9f8UBgAU1M24A4uUjXq2LbNqmyzIDlyM9skaUXS2ZL2uPtdQ45ZlLQoSfPz8yEuC6AFYgXckB2fXRdkir67n3D3n5K0Q9IFZvbKIccsufuCuy9s3749xGUBtECsaezTXM2w
yvT+Ngg6asXdnzKzOyRdIumBkGUDaKdYo0JGjZiJmTtPMaVTO5Cb2XZJz2ZB/PmSXi/p47VrBiAJdZePHVf2YHmxA22KKZ0QLfLTJX0my5PPSfpHd98boFwAkYVq2Ta5007sQJviuPMQo1buk/TqAHUB0KAUUwhS/EAb8y+MWJjZCcyoFFMIUjOBNrW9PAnkwIwa1bJt+0Sc1AJtbARyYEYVtWynkXJp+xdH2xHIgRk2rGXbdMol1Vx9m7BnJ4B1mp6Ik8q+mG1GixxIXOi0RNOjNpoe7tfFNA5bvQEJ60paoqngmvr9Yqs3oINSHUI4qKlRKF25X4PIkQMtU2XBpmkuLJWirt4vUitAi0y6a33Xcr4xpXy/SK0ACZjkT38mx1TTxftFagVoka7+6Y+4aJEDLRJy6F/KKQRUQyAHWibEn/6pD7NDNaRWgA5ituRsIZADHUSufbaQWgE6KMXNETA5AjnQoCY7IGMPs6MztT0I5EBDutQB2aXX0gXkyIGGdKkDskuvpQsI5EAEw9ZL6VIHZJdeSxfUXmvFzM6UdIOk0yStSVpy90+OOoe1VtBlo9IOXcord+m1pCLmWivHJf2Oux8ws1MkrZjZ7e7+tQBlA8kZtV5Kl9b56NJrSV3t1Iq7P+ruB7Kfvy/pkKQz6pYLpCpk2qHKkraYXUFHrZjZLkmvlnTXkOcWJS1K0vz8fMjLAq1Sdgz3uNQES9qirGCB3MxeKOlmSVe5+/cGn3f3JUlLUi9HHuq6QCghg+C4tEOZIF11SdtZyc1joyCB3MxOUi+I3+jut4QoE2hS0+OiywTpqpsSF5XJmO/uqx3IzcwkfVrSIXf/RP0qAc1rei/HMkG66jT7ojK7uk8lfihEi/xCSe+QdL+Z3ZM99ofufmuAsoFGVG391lU2SFcZGVJUZtOvDc1jz07MhDI54pTyyFXrGuq1pXSPuog9OzGzyuaIUxkXPUnOm80quo0p+ui8rq0LMq3X07X72CW0yNFZ/TTA1pO3dCpHPK2cN7n29iJHjk4aTANcfdm5Wn3mWNK53Xx+WtJUctXkyKeLHDlmymAaYPWZY7ry4rM3HJdKYBqWnx72emJLpR9h1hDI0Ull0gApdd4xFhyjEMjRSWXGaacUHMlPYxQCOTprXBogpeAYcjPlVNJJKI/OTsy0WQtqKaWTsBGdncAQs9Z5l1I6CeUxIQhogaY2kGCvzW6iRQ4EMmmapsl0R8hcO9qDQA4MMcmiVJMG46bTHbOWTpoFBHJgwCRBuU4wTmn0DNqJQI6ZM661PUlQrhOMSXegLgI5ZkqZ1vYkQbluMCbdgToI5JgpZVrbkwZlgjGmhUCOmVK2tU1QRkoI5Jgp5KPRRQRyRNe2afC0ttE1BHJExdoeQHxBpuib2fVm9riZPRCiPHRHyH0em5rGDqQmVIv8byV9StINgcpDR4Sa7ELLHigWJJC7+51mtitEWeiWUJ2Lo4YNti0HDzStsRy5mS1KWpSk+fn5pi6LFgjRuVjUsqelDjQYyN19SdKS1NtYoqnrohuKWvasrw0wagUJGdayZ8EpgECOxDHBBwgUyM3sJkkXSdpmZkclfcTdPx2ibGAcJvhg1oUatfL2EOUAAKpjz04ASByBHAASRyBHp7RxGn8b64RuYdQKOqOJyUFNbsoMlEUgR+uUCZbDjok9OWhYUO5fN+T+n0BVBHK0SpkWbNExsSYH9b80HnnqB+uC8s0HjuqWA0eD7/8JVEUgR6uUacEWHRNjclD+S2PznGnzpjmdONELyiZF2/8TqIJAjlYp04IddUzoyUH5L40Ta65fu+BMnfGS5z93zZsPHGX/T0wdgRxDTWtp2DIt2CZbuYNfGr9y3o5116O1jTYw9+YXIlxYWPDl5eXGr4tyGGmxHuudoy3MbMXdFwYfp0WODRhpsR6pEbQdE4KwQT+dsMnESAsgAbTIZ0SV9EBTOWhSFkAYBPKETBr4Jsl5x0on9F/D1pO36Jq9B8fWiWAP
jEcgT0SdDsgQOe8QATX/GubMtOa+bnLNYPl0ugLlEMgTUScY151dGCqg5l+D3DU3ZzK5Ns2ZPrdyVMdPrC+fTlegHAJ5IuoE47o571ABdfA1XH3ZuVp95pgeeeoHuuk//ntD+UxvB8ohkCeibDAuSoHUyXmHCqhFr2HlyOrQGZJMbwfKYUJQB1TtQKxzjVgBlU5NYDwmBHXUqA7EkDnl2JNimHQDTI5AnriiDkRyysDsCBLIzewSSZ+UtEnSde7+RyHKxXhFHYikKIDZUTuQm9kmSXsk/byko5LuNrMvuPvX6paN8egQBBCiRX6BpMPu/rAkmdlnJb1FEoG8IW3JL9NhCUxHiEB+hqRv5X4/KumnBw8ys0VJi5I0Pz8f4LJoE2ZhAtMTYvVDG/LYhjGN7r7k7gvuvrB9+/YAl0WbDJs0BKAZIQL5UUln5n7fIemRAOUiISx9C0xPiNTK3ZJeYWYvl/RtSZdL+vUA5SIhdLoC01M7kLv7cTN7n6QvqTf88Hp3P1i7ZkjCYAdnfto9QR1oRpBx5O5+q6RbQ5SFdBR1cNLxCTSLrd4wsaIOTjo+gWYRyDts5ciq9uw7rJUjq1HKL+rgpOMTaBarH3ZUU+mNolw4OXIgPFY/nDGjNoMIGWSLZpW2ZbYpMAsI5B1VtBkEHZFA9xDIO6poXHdbNmIGEA6BvAHTCnzD0htt2YgZQDgE8sjaFvjashEzgHAI5JHFCnx1Wvlt2IgZQDgE8shiBL5ptvJZUwVoHwJ5ZDEC37TTGwwtBNqFQN6A0IGP9AaAPAL5BKY9/C5WemParwvAZAjkFbVlFEroVn5bXheA6lg0q6KuruzX1dcFzAICeUVdXdmvq68LmAWsfjiBruaSu/q6gK5g9cOAujT8rmirNgDpIJDPMDo4gW4gR95CsXf26aODE+gGWuQt02QrmYlFQDfUapGb2dvM7KCZrZnZhgQ8qmuyldyfWPSBN5xDWgVIWN0W+QOSflnSXweoC9R8K5kOTiB9tQK5ux+SJDMLUxtsmH4vSXv2HWZIIIBCjeXIzWxR0qIkzc/PN3XZ56Q0RrrfSmZUCYAyxgZyM/uypNOGPPVhd/982Qu5+5KkJak3Iah0DQNoKiCG/rKY9nK1ANIwNpC7++ubqEhMTQTEGF8WjCoBUMZMDD9sIiDG+LJgNx4AZdQK5Gb2Vkl/Lmm7pH8xs3vc/ReC1CygJgJiiN3ph9WPUSUAxpnJRbNG5bLr5LknPZdOTQBlsGhWZlTQrBtQ863nKkGdTk0AdczcWiujZk5OMqty2Loo/S+Ea297UFdct3/smimsBQ6gjplrkY/KZVfNcxe14Ku2sOnUBFBHsoF80nz0qKBZNaAWBeyyXwisBQ4ghKQCeT/wbT15i67ZezBILnvUc+O+LIoCdpkvBDo4AYSSTCDPB745M625T32Cz7jW/ag60cEJIJRkAnk+8Mldc3Mmk099gs+kKRFmbQIIJZlAPhj4rr7sXK0+c6y1E3zGoYMTQChJTQhqegXD/PUkEXQBTFUnJgQ1PbKD5WQBpGDmJgRNgk2KAbQZgbwEZl4CaLOkUivTErJjMqWdigCkoXOBPFagDJGfJ9cOIIZOBfK2B0omAQGIoVM58rZ3SpJrBxBDp1rkbZ8tySQgADEkNSGoCBN3AMyCTkwIGmZYXvzKi8+edrUAoDHJ58jbnhcHgNiSD+SjOhCHbcMGAF1TK7ViZn8i6U2Sjkl6SNJvu/tTISpWVlEHYtuHIgJAKHVb5LdLeqW7v0rS1yV9qH6Vqjt/51ZdefHZ6wJ12ZQLrXYAqavVInf323K/7pf0q/WqE06ZoYi02gF0QchRK++S9A9FT5rZoqRFSZqfnw942eHKjNlmpiWALhgbyM3sy5JOG/LUh93989kxH5Z0XNKNReW4+5KkJak3jnyi2lY0bn2Utk8gAoAyxgZyd3/9qOfN7J2SLpP0Op/G7KIamGkJoAvqjlq5RNIfSHqtuz8TpkrNanrX
IQAIre6olU9JOkXS7WZ2j5n9VYA6AQAqqDtqhbnwADBlyc/sBIBZRyAHgMQRyAEgcQRyAEjcVDaWMLMnJB2Z8PRtkr4TsDqhUK9qqFc11KuattZLqle3ne6+ffDBqQTyOsxsedgOGdNGvaqhXtVQr2raWi8pTt1IrQBA4gjkAJC4FAP50rQrUIB6VUO9qqFe1bS1XlKEuiWXIwcArJdiixwAkEMgB4DEtTKQm9nbzOygma2ZWeEwHTO7xMweNLPDZvbB3OMvNbPbzewb2X+DrFNbplwzOydbCbL/73tmdlX23EfN7Nu55y5tql7Zcd80s/uzay9XPT9GvczsTDPbZ2aHsvf8/bnngt6vos9L7nkzsz/Lnr/PzM4re27kel2R1ec+M/uqmf1k7rmh72lD9brIzL6be3+uLntu5Hr9Xq5OD5jZCTN7afZclPtlZteb2eNm9kDB83E/W+7eun+SflzSOZLukLRQcMwmSQ9JOkvSFkn3SvqJ7Lk/lvTB7OcPSvp4oHpVKjer4/+oN4hfkj4q6Xcj3K9S9ZL0TUnb6r6ukPWSdLqk87KfT1FvE+/++xjsfo36vOSOuVTSFyWZpN2S7ip7buR6vUbS1uznN/brNeo9baheF0naO8m5Mes1cPybJH2lgfv1c5LOk/RAwfNRP1utbJG7+yF3f3DMYRdIOuzuD7v7MUmflfSW7Lm3SPpM9vNnJP1SoKpVLfd1kh5y90lnsZZV9/VO7X65+6PufiD7+fuSDkk6I9D180Z9XvL1vcF79kt6iZmdXvLcaPVy96+6+2r2635JOwJdu1a9Ip0buuy3S7op0LULufudkv53xCFRP1utDOQlnSHpW7nfj+qHAeBH3f1RqRcoJL0s0DWrlnu5Nn6I3pf9aXV9qBRGhXq5pNvMbMV6m2FXPT9WvSRJZrZL0qsl3ZV7ONT9GvV5GXdMmXNj1ivv3eq17PqK3tOm6vUzZnavmX3RzM6teG7MesnMTpZ0iaSbcw/Hul/jRP1s1dpYog4rsanzuCKGPFZ7LOWoelUsZ4ukN0v6UO7hv5T0MfXq+TFJ10p6V4P1utDdHzGzl6m3q9N/ZS2JiQW8Xy9U73+4q9z9e9nDE9+vYZcY8tjg56XomCiftTHX3Hig2cXqBfKfzT0c/D2tUK8D6qUNn876L/5Z0itKnhuzXn1vkvTv7p5vKce6X+NE/WxNLZD7mE2dSzgq6czc7zskPZL9/JiZne7uj2Z/vjweol5mVqXcN0o64O6P5cp+7mcz+xtJe5usl7s/kv33cTP7J/X+rLtTU75fZnaSekH8Rne/JVf2xPdriFGfl3HHbClxbsx6ycxeJek6SW909yf7j494T6PXK/eFK3e/1cz+wsy2lTk3Zr1yNvxFHPF+jRP1s5VyauVuSa8ws5dnrd/LJX0he+4Lkt6Z/fxOSWVa+GVUKXdDbi4LZn1vlTS0hztGvczsBWZ2Sv9nSW/IXX9q98vMTNKnJR1y908MPBfyfo36vOTr+5vZCIPdkr6bpYTKnButXmY2L+kWSe9w96/nHh/1njZRr9Oy909mdoF68eTJMufGrFdWnxdLeq1yn7nI92ucuJ+t0L23If6p9z/tUUn/J+kxSV/KHv8xSbfmjrtUvVEOD6mXkuk/fqqkf5X0jey/Lw1Ur6HlDqnXyep9oF88cP7fSbpf0n3Zm3V6U/VSr1f83uzfwbbcL/XSBJ7dk3uyf5fGuF/DPi+S3ivpvdnPJmlP9vz9yo2YKvqsBbpP4+p1naTV3P1ZHveeNlSv92XXvVe9TtjXtOF+Zb//lqTPDpwX7X6p12h7VNKz6sWudzf52WKKPgAkLuXUCgBABHIASB6BHAASRyAHgMQRyAEgcQRyAEgcgRwAEvf/XQ45Lxw8fEYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(trX, trY, '.')"
   ]
  },
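  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Building the Model\n",
    "First we create a Sequential model. All we need is a single connection from the input to the output, so one Dense layer with a linear activation is enough."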
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "# Keras 1.x spelled these arguments output_dim=1 and init='uniform':\n",
    "# model.add(Dense(input_dim=1, output_dim=1, init='uniform', activation='linear'))\n",
    "# One unit with a linear activation computes y = w*x + b\n",
    "model.add(Dense(input_dim=1, units=1, kernel_initializer='uniform', activation='linear'))"
   ]
  },
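  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Dense layer holds a weight `w` and a bias `b` and maps an input `x` to `w*x + b`. Let's look at how they are initialized. If a weights file saved by an earlier run exists, it is loaded first:"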
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear regression model is initialized with weights w: 1.95, b: 1.10\n"
     ]
    }
   ],
   "source": [
    "# Resume from the weights saved by an earlier run, if there is one\n",
    "if os.path.exists('my_model_weights.h5'):\n",
    "    model.load_weights('my_model_weights.h5')\n",
    "\n",
    "# get_weights() returns [kernel, bias]; kernel has shape (1, 1), bias has shape (1,)\n",
    "weights = model.layers[0].get_weights()\n",
    "w_init = weights[0][0][0]\n",
    "b_init = weights[1][0]\n",
    "print('Linear regression model is initialized with weights w: %.2f, b: %.2f' % (w_init, b_init))\n",
    "# Sample output from a fresh run with no saved weights: w: -0.03, b: 0.00"
   ]
  },
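  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can train this linear model on the `trX` and `trY` data we generated, where `trY` is roughly 3 times `trX` plus the offset 1.25. The weight `w` should therefore approach 3 and the bias `b` should approach 1.25. We use plain stochastic gradient descent (SGD) as the optimizer and mean squared error (MSE) as the loss:"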
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer='sgd', loss='mse')"
   ]
  },
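  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the optimizer and the loss less of a black box, the next cell is a minimal NumPy sketch of gradient descent on the MSE of a line. It is an illustration under stated assumptions, not what Keras runs internally: it uses full-batch updates (Keras's `sgd` works on mini-batches), noise-free targets, and a learning rate of 0.1 with 500 steps picked just for this sketch. The names `x_gd`, `y_gd`, `w_gd`, `b_gd`, and `lr_gd` are hypothetical and chosen so they do not overwrite the notebook's own variables. With these settings the estimates should end up very close to 3 and 1.25."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Noise-free line with the same slope and intercept as the training data (assumption: no noise)\n",
    "x_gd = np.linspace(-1, 1, 101)\n",
    "y_gd = 3 * x_gd + 1.25\n",
    "\n",
    "w_gd, b_gd = 0.0, 0.0  # start from zero\n",
    "lr_gd = 0.1            # assumed learning rate\n",
    "\n",
    "for _ in range(500):\n",
    "    err = w_gd * x_gd + b_gd - y_gd          # prediction error per point\n",
    "    w_gd -= lr_gd * 2 * np.mean(err * x_gd)  # d(MSE)/dw = 2 * mean(err * x)\n",
    "    b_gd -= lr_gd * 2 * np.mean(err)         # d(MSE)/db = 2 * mean(err)\n",
    "\n",
    "print('Hand-written gradient descent: w = %.2f, b = %.2f' % (w_gd, b_gd))"
   ]
  },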
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we call `fit` to feed the data to the model and train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "101/101 [==============================] - 0s 355us/step - loss: 0.5725\n",
      "Epoch 2/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.5520\n",
      "Epoch 3/20\n",
      "101/101 [==============================] - 0s 79us/step - loss: 0.5304\n",
      "Epoch 4/20\n",
      "101/101 [==============================] - 0s 109us/step - loss: 0.5148\n",
      "Epoch 5/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.4987\n",
      "Epoch 6/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.4854\n",
      "Epoch 7/20\n",
      "101/101 [==============================] - 0s 109us/step - loss: 0.4705\n",
      "Epoch 8/20\n",
      "101/101 [==============================] - 0s 99us/step - loss: 0.4577\n",
      "Epoch 9/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.4454\n",
      "Epoch 10/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.4328\n",
      "Epoch 11/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.4185\n",
      "Epoch 12/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.4083\n",
      "Epoch 13/20\n",
      "101/101 [==============================] - 0s 69us/step - loss: 0.3999\n",
      "Epoch 14/20\n",
      "101/101 [==============================] - 0s 69us/step - loss: 0.3915\n",
      "Epoch 15/20\n",
      "101/101 [==============================] - 0s 69us/step - loss: 0.3822\n",
      "Epoch 16/20\n",
      "101/101 [==============================] - 0s 79us/step - loss: 0.3715\n",
      "Epoch 17/20\n",
      "101/101 [==============================] - 0s 79us/step - loss: 0.3636\n",
      "Epoch 18/20\n",
      "101/101 [==============================] - 0s 89us/step - loss: 0.3559\n",
      "Epoch 19/20\n",
      "101/101 [==============================] - 0s 69us/step - loss: 0.3483\n",
      "Epoch 20/20\n",
      "101/101 [==============================] - 0s 79us/step - loss: 0.3419\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.callbacks.History at 0x21dbe4e4648>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(trX, trY, epochs=20, verbose=1)"
   ]
  },
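  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: compared with this gradient-based training, classical linear regression (ordinary least squares) has one clear advantage: it has a closed-form solution, so no iterative training is needed."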
   ]
  },
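  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of that closed-form route, the next cell fits a line with `np.polyfit`, which solves the least-squares problem directly. This is a sketch, not part of the original workflow: it regenerates data in the same way as section 2, with a fixed random seed added as an assumption so the result is repeatable. The names `rng_demo`, `x_demo`, `y_demo`, `w_ols`, and `b_ols` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Same recipe as the training data above; the fixed seed is an assumption for repeatability\n",
    "rng_demo = np.random.RandomState(0)\n",
    "x_demo = np.linspace(-1, 1, 101)\n",
    "y_demo = 3 * x_demo + 1.25 + rng_demo.randn(*x_demo.shape) * 0.5\n",
    "\n",
    "# Degree-1 least-squares fit; coefficients come back highest power first (slope, intercept)\n",
    "w_ols, b_ols = np.polyfit(x_demo, y_demo, 1)\n",
    "print('Least-squares fit: w = %.2f, b = %.2f' % (w_ols, b_ols))"
   ]
  },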
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在经过训练之后，我们再次打印权重："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear regression model is trained to have weight w: 2.38, b: 1.14\n"
     ]
    }
   ],
   "source": [
    "weights = model.layers[0].get_weights()\n",
    "w_final = weights[0][0][0]\n",
    "b_final = weights[1][0]\n",
    "print('Linear regression model is trained to have weight w: %.2f, b: %.2f' % (w_final, b_final))\n",
    "# Sample output from a different run: w: 2.94, b: 0.08"
   ]
  },
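  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, after these 20 epochs the weight has moved from 1.95 to 2.38, toward the true value of 3, and the bias from 1.10 to 1.14, toward 1.25. Because the weights are saved at the end and reloaded at the start, each rerun of the notebook continues training from where the previous run stopped. You can also set `epochs` to a value in the range [100, 300] and observe how the result changes. You have now built a linear regression model with very little code; building the same model directly in TensorFlow would take more."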
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_weights(\"my_model_weights.h5\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
